library(Hmisc)
library(data.table)
library(kableExtra)

#' Summarise variables, optionally by group and with observation weights
#'
#' Numeric columns get the non-missing count, mean, variance, minimum,
#' maximum and the 10/25/50/75/90 percent quantiles. Factor and character
#' columns are skipped unless `include_factors = TRUE`, in which case they get
#' the row count, the number of distinct values, the most frequent value and
#' its share.
#'
#' NOTE(review): a later definition of `generate_summary_stats()` in this file
#' (without `group_var`) replaces this one when the file is sourced top to
#' bottom.
#'
#' @param dt A data.frame or data.table.
#' @param vars_to_summarize Character vector of column names to summarise.
#' @param weight_var Optional weight column, given as a bare name or a string.
#'   When supplied, mean, variance and quantiles are weighted; `N`, `Min` and
#'   `Max` are not. A weighted statistic that raises an error is reported as
#'   `NA`.
#' @param group_var Optional grouping column, given as a bare name or a
#'   string.
#' @param include_factors Summarise factor/character columns as well?
#' @return A data.table with one row per group and variable. `Group`
#'   (character; `"All"` when no `group_var` is given) and `Variable` come
#'   first, with underscores replaced by spaces, followed by the statistics
#'   as doubles.
generate_summary_stats <- function(dt, vars_to_summarize, weight_var = NULL,
                                   group_var = NULL,
                                   include_factors = FALSE) {
  # Capture the column arguments unevaluated with base R so that both bare
  # names and strings work without rlang having to be attached.
  weight_expr <- substitute(weight_var)
  group_expr <- substitute(group_var)
  column_name <- function(expr) {
    if (is.null(expr)) {
      return(NULL)
    }
    if (is.character(expr)) {
      return(expr)
    }
    paste(deparse(expr), collapse = "")
  }
  weight_var_name <- column_name(weight_expr)
  group_var_name <- column_name(group_expr)

  dt <- as.data.table(dt)
  has_group <- !is.null(group_var_name) && nzchar(group_var_name)
  by_cols <- if (has_group) group_var_name else character(0)
  use_weights <- !is.null(weight_var_name)

  needed <- unique(c(vars_to_summarize, weight_var_name, by_cols))
  absent <- setdiff(needed, names(dt))
  if (length(absent) > 0L) {
    stop("Column(s) not found in `dt`: ", paste(absent, collapse = ", "),
         call. = FALSE)
  }

  # Statistics for one factor/character vector (one group at a time).
  summarise_factor <- function(x) {
    counts <- table(x)
    # `table()` drops NA, so an all-NA vector has nothing to report.
    has_values <- length(counts) > 0L && sum(counts) > 0L
    list(
      N = as.double(length(x)),
      Levels = as.double(length(unique(x))),
      MostFrequent = if (has_values) {
        names(counts)[which.max(counts)]
      } else {
        NA_character_
      },
      Proportion = if (has_values) {
        as.double(max(counts) / sum(counts))
      } else {
        NA_real_
      }
    )
  }

  # Statistics for one numeric vector; `w` holds the weights or is NULL.
  summarise_numeric <- function(x, w = NULL) {
    if (is.null(w)) {
      q <- function(p) as.double(quantile(x, probs = p, na.rm = TRUE))
      avg <- as.double(mean(x, na.rm = TRUE))
      variance <- as.double(stats::var(x, na.rm = TRUE))
      med <- as.double(median(x, na.rm = TRUE))
    } else {
      # Weighted statistics are best-effort: an error becomes NA.
      or_na <- function(expr) {
        tryCatch(as.double(expr), error = function(e) NA_real_)
      }
      q <- function(p) {
        or_na(wtd.quantile(x, weights = w, probs = p, na.rm = TRUE))
      }
      avg <- or_na(weighted.mean(x, w = w, na.rm = TRUE))
      variance <- or_na(wtd.var(x, weights = w, na.rm = TRUE))
      med <- q(0.5)
    }
    list(
      N = as.double(sum(!is.na(x))),
      Mean = avg,
      Var = variance,
      Min = as.double(min(x, na.rm = TRUE)),
      Q10 = q(0.1),
      Q25 = q(0.25),
      Median = med,
      Q75 = q(0.75),
      Q90 = q(0.9),
      Max = as.double(max(x, na.rm = TRUE))
    )
  }

  # Adapters taking `.SD`: the value column is first, the weight column last.
  # Passing `.SD` to a function keeps column names out of the `j` expression.
  factor_from_sd <- function(sd) summarise_factor(sd[[1L]])
  numeric_from_sd <- function(sd) {
    if (use_weights) {
      summarise_numeric(sd[[1L]], sd[[length(sd)]])
    } else {
      summarise_numeric(sd[[1L]])
    }
  }

  per_variable <- lapply(vars_to_summarize, function(v) {
    column <- dt[[v]]
    if (is.factor(column) || is.character(column)) {
      if (!include_factors) {
        return(NULL)
      }
      stats <- dt[, factor_from_sd(.SD), by = c(by_cols), .SDcols = v]
    } else {
      cols <- if (use_weights) unique(c(v, weight_var_name)) else v
      stats <- dt[, numeric_from_sd(.SD), by = c(by_cols), .SDcols = cols]
    }
    set(stats, j = "Variable", value = v)
    stats
  })
  per_variable <- Filter(Negate(is.null), per_variable)

  if (length(per_variable) == 0L) {
    return(data.table(Group = character(0), Variable = character(0)))
  }

  # `fill = TRUE` adds the columns a table lacks (e.g. MostFrequent for
  # numeric variables) as NA of the matching type.
  summary_table <- rbindlist(per_variable, use.names = TRUE, fill = TRUE)

  if (has_group) {
    setnames(summary_table, group_var_name, "Group")
  } else {
    set(summary_table, j = "Group", value = "All")
  }
  setcolorder(summary_table, c("Group", "Variable"))

  # Display labels: underscores become spaces (Group ends up as character).
  for (col in c("Group", "Variable")) {
    set(summary_table, j = col,
        value = gsub("_", " ", summary_table[[col]], fixed = TRUE))
  }

  summary_table
}


#' Level shares of factor variables by group and treatment status
#'
#' For every value of `group_var`, counts rows per `placebo_var` value and
#' spreads the share of each level of a factor variable (relative to all rows
#' of that group) into one column per level.
#'
#' NOTE(review): later definitions of `generate_factor_summary()` in this file
#' replace this one when the file is sourced top to bottom.
#'
#' @param dt A data.frame or data.table.
#' @param factor_vars Character vector of column names to tabulate.
#' @param group_var Name (string) of the grouping column. Factor levels are
#'   used in level order; other types use their sorted distinct values.
#' @param placebo_var Name (string) of the treatment-status column.
#' @return A data.table with `placebo_var`, `N` (rows per treatment status
#'   within the group), one share column per level, `Group` and `Variable`.
generate_factor_summary <- function(dt, factor_vars, group_var, placebo_var) {
  dt <- as.data.table(dt)

  absent <- setdiff(c(factor_vars, group_var, placebo_var), names(dt))
  if (length(absent) > 0L) {
    stop("Column(s) not found in `dt`: ", paste(absent, collapse = ", "),
         call. = FALSE)
  }

  group_values <- dt[[group_var]]
  # `levels()` is NULL for non-factors, so fall back to the distinct values.
  group_levels <- if (is.factor(group_values)) {
    levels(group_values)
  } else {
    as.character(sort(unique(group_values)))
  }
  group_labels <- as.character(group_values)

  # Wide share table of `factor_var` for the rows of one group level.
  summarise_level <- function(factor_var, lvl) {
    rows <- which(group_labels == lvl)
    if (length(rows) == 0L) {
      # Unused level: dcast() cannot cast an empty table, so skip it.
      return(NULL)
    }
    sub <- dt[rows]
    total <- nrow(sub)

    counts <- sub[, .N, by = c(placebo_var, factor_var)]
    set(counts, j = "Proportion", value = counts[["N"]] / total)
    wide <- dcast(counts,
                  as.formula(paste(placebo_var, "~", factor_var)),
                  value.var = "Proportion")
    set(wide, j = "Group", value = lvl)

    condition_total <- sub[, .N, by = c(placebo_var)]
    condition_total[wide, on = placebo_var]
  }

  summary_list <- lapply(factor_vars, function(factor_var) {
    pieces <- lapply(group_levels, function(lvl) {
      summarise_level(factor_var, lvl)
    })
    pieces <- Filter(Negate(is.null), pieces)
    if (length(pieces) == 0L) {
      return(NULL)
    }
    # Bind once instead of growing the table inside a loop.
    full <- rbindlist(pieces, use.names = TRUE, fill = TRUE)
    set(full, j = "Variable", value = factor_var)
    full
  })

  rbindlist(Filter(Negate(is.null), summary_list),
            fill = TRUE, use.names = TRUE)
}

# Example usage:





#' Level shares of factor variables by assigned group and treatment status
#'
#' For every value of `group_var`, counts rows per `placebo_var` value and
#' spreads the share of each level of a factor variable (relative to all rows
#' of that group) into one column per level. Output columns are renamed for
#' presentation.
#'
#' NOTE(review): a later two-argument definition of
#' `generate_factor_summary()` in this file replaces this one when the file is
#' sourced top to bottom.
#'
#' @param dt A data.frame or data.table.
#' @param factor_vars Character vector of column names to tabulate.
#' @param group_var Name (string) of the grouping column. Factor levels are
#'   used in level order; other types use their sorted distinct values.
#' @param placebo_var Name (string) of the treatment-status column.
#' @return A data.table whose first columns are `Treatment Status`, `Total`
#'   (rows per treatment status within the group), `Assigned Rumor` and
#'   `Variable`, followed by one share column per level.
generate_factor_summary <- function(dt, factor_vars, group_var, placebo_var) {
  dt <- as.data.table(dt)

  absent <- setdiff(c(factor_vars, group_var, placebo_var), names(dt))
  if (length(absent) > 0L) {
    stop("Column(s) not found in `dt`: ", paste(absent, collapse = ", "),
         call. = FALSE)
  }

  group_values <- dt[[group_var]]
  # `levels()` is NULL for non-factors, so fall back to the distinct values.
  group_levels <- if (is.factor(group_values)) {
    levels(group_values)
  } else {
    as.character(sort(unique(group_values)))
  }
  group_labels <- as.character(group_values)

  # Wide share table of `factor_var` for the rows of one group level.
  summarise_level <- function(factor_var, lvl) {
    rows <- which(group_labels == lvl)
    if (length(rows) == 0L) {
      # Unused level: dcast() cannot cast an empty table, so skip it.
      return(NULL)
    }
    sub <- dt[rows]
    total <- nrow(sub)

    counts <- sub[, .N, by = c(placebo_var, factor_var)]
    set(counts, j = "Proportion", value = counts[["N"]] / total)
    wide <- dcast(counts,
                  as.formula(paste(placebo_var, "~", factor_var)),
                  value.var = "Proportion")
    set(wide, j = "Group", value = lvl)

    condition_total <- sub[, .N, by = c(placebo_var)]
    condition_total[wide, on = placebo_var]
  }

  summary_list <- lapply(factor_vars, function(factor_var) {
    pieces <- lapply(group_levels, function(lvl) {
      summarise_level(factor_var, lvl)
    })
    pieces <- Filter(Negate(is.null), pieces)
    if (length(pieces) == 0L) {
      return(NULL)
    }
    # Bind once instead of growing the table inside a loop.
    full <- rbindlist(pieces, use.names = TRUE, fill = TRUE)

    # Presentation names for the treatment, count and group columns.
    setnames(full, old = placebo_var, new = "Treatment Status")
    setnames(full, old = "N", new = "Total")
    setnames(full, old = "Group", new = "Assigned Rumor")

    set(full, j = "Variable", value = factor_var)
    full
  })

  factor_summary_table <- rbindlist(Filter(Negate(is.null), summary_list),
                                    fill = TRUE, use.names = TRUE)
  if (ncol(factor_summary_table) == 0L) {
    # Nothing to tabulate; there are no columns to reorder.
    return(factor_summary_table)
  }

  # Identification columns first, level-share columns after.
  desired_order <- c("Treatment Status", "Total", "Assigned Rumor", "Variable")
  setcolorder(factor_summary_table, desired_order)

  factor_summary_table
}


#' Format a table, render it as a booktabs LaTeX table and save it
#'
#' Numeric columns are printed with `digits` decimals, underscores in
#' character/factor columns become spaces, and an `N` column (if present) is
#' printed with a thousands separator. The first (up to) three columns are
#' left-aligned, the rest right-aligned; the header row and the first (up to)
#' two columns are bold.
#'
#' @param dt A data.frame or data.table; it is not modified.
#' @param filename Path passed to `save_kable()`.
#' @param title Used as the caption when `caption` is NULL.
#' @param caption Table caption.
#' @param label Table label; by default derived from the lower-cased
#'   `filename` with whitespace replaced by underscores.
#' @param note Optional general footnote.
#' @param digits Number of decimals for numeric columns.
#' @return The kable object (after styling), which is also written to
#'   `filename`.
create_latex_table <- function(dt, filename, title = NULL, caption = NULL, label = NULL, note = NULL, digits = 3) {
  if (is.null(caption)) {
    caption <- title
  }
  if (is.null(label)) {
    # NOTE(review): knitr::kable() adds its own table label prefix ("tab:" by
    # default), so this default probably renders as \label{tab:tab:...}. Left
    # unchanged so existing \ref{}s keep resolving -- confirm before changing.
    label <- paste0("tab:", gsub("\\s+", "_", tolower(filename)))
  }

  # Work on a private data.table so plain data.frames are accepted and the
  # caller's object is never modified by reference.
  dt_formatted <- copy(as.data.table(dt))

  # Classify columns before any of them is turned into text.
  is_numeric_col <- vapply(dt_formatted, is.numeric, logical(1))
  is_text_col <- vapply(
    dt_formatted,
    function(x) is.character(x) || is.factor(x),
    logical(1)
  )
  # `N` gets its own format below, so keep it out of the decimal formatting.
  numeric_cols <- setdiff(names(dt_formatted)[is_numeric_col], "N")
  text_cols <- names(dt_formatted)[is_text_col]

  number_format <- paste0("%.", digits, "f")
  for (col in numeric_cols) {
    set(dt_formatted, j = col,
        value = sprintf(number_format, dt_formatted[[col]]))
  }
  for (col in text_cols) {
    set(dt_formatted, j = col, value = gsub("_", " ", dt_formatted[[col]]))
  }

  # Format 'N' separately, with a thousands separator.
  if ("N" %in% names(dt_formatted)) {
    n_values <- dt_formatted[["N"]]
    if (is.numeric(n_values)) {
      n_values <- round(n_values, digits)
    }
    set(dt_formatted, j = "N",
        value = format(as.numeric(n_values), big.mark = ","))
  }

  # Size the alignment vector to the table so narrow tables do not fail.
  n_cols <- ncol(dt_formatted)
  n_left <- min(3L, n_cols)
  col_align <- c(rep("l", n_left), rep("r", n_cols - n_left))
  bold_cols <- seq_len(min(2L, n_cols))

  latex_table <- kable(dt_formatted, format = "latex", booktabs = TRUE,
                       caption = caption,
                       label = label,
                       align = col_align,
                       escape = FALSE) %>%
    kable_styling(latex_options = c("hold_position"))

  if (length(bold_cols) > 0L) {
    latex_table <- latex_table %>%
      column_spec(bold_cols, bold = TRUE)
  }
  latex_table <- latex_table %>%
    row_spec(0, bold = TRUE)

  if (!is.null(note)) {
    latex_table <- latex_table %>%
      footnote(general = note, general_title = "Note:", footnote_as_chunk = TRUE, threeparttable = TRUE)
  }

  save_kable(latex_table, file = filename)

  cat("Table saved to:", filename, "\n")

  latex_table
}



#' Summarise variables over the whole table, optionally with weights
#'
#' Numeric columns get the non-missing count, mean, minimum, maximum and the
#' 10/25/50/75/90 percent quantiles; the variance is reported only in the
#' unweighted case. Factor and character columns are skipped unless
#' `include_factors = TRUE`, in which case they get the row count, the number
#' of distinct values, the most frequent value and its share.
#'
#' @param dt A data.frame or data.table.
#' @param vars_to_summarize Character vector of column names to summarise.
#' @param weight_var Optional weight column, given as a bare name or a string.
#'   When supplied, the mean and quantiles are weighted; `N`, `Min` and `Max`
#'   are not.
#' @param include_factors Summarise factor/character columns as well?
#' @return A data.table with one row per variable: `Variable` first (with
#'   underscores replaced by spaces), followed by the statistics as doubles.
generate_summary_stats <- function(dt, vars_to_summarize, weight_var = NULL, include_factors = FALSE) {
  # Capture the weight argument unevaluated with base R so that both a bare
  # name and a string work without rlang having to be attached.
  weight_expr <- substitute(weight_var)
  weight_var_name <- if (is.null(weight_expr)) {
    NULL
  } else if (is.character(weight_expr)) {
    weight_expr
  } else {
    paste(deparse(weight_expr), collapse = "")
  }

  dt <- as.data.table(dt)

  absent <- setdiff(unique(c(vars_to_summarize, weight_var_name)), names(dt))
  if (length(absent) > 0L) {
    stop("Column(s) not found in `dt`: ", paste(absent, collapse = ", "),
         call. = FALSE)
  }

  weights <- if (is.null(weight_var_name)) NULL else dt[[weight_var_name]]

  # One-row summary for a single column; NULL for skipped factor columns.
  summarise_one <- function(v) {
    x <- dt[[v]]

    if (is.factor(x) || is.character(x)) {
      if (!include_factors) {
        return(NULL)
      }
      counts <- table(x)
      # `table()` drops NA, so an all-NA column has nothing to report.
      has_values <- length(counts) > 0L && sum(counts) > 0L
      return(data.table(
        N = as.double(length(x)),
        Levels = as.double(length(unique(x))),
        MostFrequent = if (has_values) {
          names(counts)[which.max(counts)]
        } else {
          NA_character_
        },
        Proportion = if (has_values) {
          as.double(max(counts) / sum(counts))
        } else {
          NA_real_
        },
        Variable = v
      ))
    }

    n_obs <- as.double(sum(!is.na(x)))
    lowest <- as.double(min(x, na.rm = TRUE))
    highest <- as.double(max(x, na.rm = TRUE))

    if (is.null(weights)) {
      q <- function(p) as.double(quantile(x, probs = p, na.rm = TRUE))
      return(data.table(
        N = n_obs,
        Mean = as.double(mean(x, na.rm = TRUE)),
        Var = as.double(stats::var(x, na.rm = TRUE)),
        Min = lowest,
        Q10 = q(0.1),
        Q25 = q(0.25),
        Median = as.double(median(x, na.rm = TRUE)),
        Q75 = q(0.75),
        Q90 = q(0.9),
        Max = highest,
        Variable = v
      ))
    }

    wq <- function(p) {
      as.double(wtd.quantile(x, weights = weights, probs = p, na.rm = TRUE))
    }
    # No `Var` column here: the weighted variance is not computed.
    data.table(
      N = n_obs,
      Mean = as.double(weighted.mean(x, w = weights, na.rm = TRUE)),
      Min = lowest,
      Q10 = wq(0.1),
      Q25 = wq(0.25),
      Median = wq(0.5),
      Q75 = wq(0.75),
      Q90 = wq(0.9),
      Max = highest,
      Variable = v
    )
  }

  summary_stats <- Filter(Negate(is.null),
                          lapply(vars_to_summarize, summarise_one))
  if (length(summary_stats) == 0L) {
    return(data.table(Variable = character(0)))
  }

  # `fill = TRUE` adds the columns a table lacks (e.g. MostFrequent for
  # numeric variables) as NA of the matching type.
  summary_table <- rbindlist(summary_stats, fill = TRUE, use.names = TRUE)
  setcolorder(summary_table, "Variable")

  # Display labels: underscores become spaces.
  set(summary_table, j = "Variable",
      value = gsub("_", " ", summary_table[["Variable"]], fixed = TRUE))

  summary_table
}


#' Share of rows at each level of one or more columns
#'
#' Builds one row per entry of `factor_vars`: the share of all rows of `dt`
#' falling on each distinct value of that column, spread into one column per
#' value. Rows for different variables are stacked, with columns a variable
#' does not have filled in as missing.
#'
#' @param dt A data.frame or data.table.
#' @param factor_vars Character vector of column names to tabulate.
#' @return A data.table with `Variable`, `N` (the number of rows of `dt`) and
#'   one share column per distinct value.
generate_factor_summary <- function(dt, factor_vars) {
  dt <- as.data.table(dt)

  # The denominator is the full row count, identical for every variable.
  total <- nrow(dt)

  # One-row wide table of shares for a single column.
  shares_for <- function(var) {
    temp <- dt[, .(N = .N, Proportion = .N / total), by = var]
    wide <- dcast(temp, 1 ~ get(var), value.var = "Proportion")
    wide[, Variable := var]
    wide[, N := total]
    wide
  }

  per_factor <- lapply(factor_vars, shares_for)
  factor_summary_table <- rbindlist(per_factor, use.names = TRUE, fill = TRUE)

  # Identification columns first, share columns after.
  leading <- c("Variable", "N")
  trailing <- setdiff(names(factor_summary_table), leading)
  setcolorder(factor_summary_table, c(leading, trailing))

  # Drop the column named "." (presumably the placeholder column the cast
  # creates for its constant left-hand side -- confirm).
  factor_summary_table[, `.` := NULL]

  factor_summary_table
}
